import random

from torch.utils import data
from tqdm import tqdm

from knowledge_tracing.args import ARGS
from datasets.interaction import InteractionParser
from augmentations import random_deletion, random_insertion, skill_based_replacement\

import numpy as np

# Generic dataset for sequence
# Parser interface: locate files and supply items
# Different types of interfaces
# Features: some might be shared, some might not


class InteractionDataset(data.Dataset):
    """Dataset of per-user interaction windows, with optional augmented views.

    Every sample is a window ``(user_idx, start, end)`` over one user's
    interactions. The maximum window length is read from ``ARGS.seq_size``
    (it is not a constructor argument).

    Args:
        iparser: parser used to query ``num_interactions(user)`` and
            ``parse_interactions(user, start, end)``.
        users: either a path (``str``) to a text file with one user id per
            line, or an iterable of user ids.
        features: sequence of feature entries; only the first element of each
            entry is used, and it must expose ``name`` and
            ``tensorize(interactions, seq_size)``.
        is_training: selects training-style window generation and enables
            the augmentations in ``__getitem__``.
        aug_methods: iterable of augmentation keys among ``'del'``, ``'ins'``
            and ``'rep'``. ``None`` (the default) means no augmentation.
        fraction: if given, only the first ``int(fraction * len(users))``
            users are kept.
    """

    def __init__(self, iparser: 'InteractionParser', users, features, is_training: bool, aug_methods=None, fraction=None):
        self._iparser = iparser
        if isinstance(users, str):
            # read list of users from file path; the context manager makes
            # sure the file handle is released
            with open(users, 'r') as user_file:
                users = [line.strip() for line in user_file]
        # np.bytes_ is the portable spelling of the removed np.string_ alias
        self._users = np.array(users, dtype=np.bytes_)
        if fraction is not None:
            self._users = self._users[:int(fraction * len(self._users))]
        self._features = features
        self._feature_names = [f[0].name for f in self._features]
        self._is_training = is_training

        # entries are of form (user1, s1, e1), (user2, s2, e2), ...
        print('Initializing dataset...')
        user_index_ranges = []
        for user_idx, user in enumerate(tqdm(self._users)):
            ninters = iparser.num_interactions(user.astype(str))
            if is_training:
                user_ranges = self._training_ranges(
                    user_idx, ninters, ARGS.seq_size, ARGS.augment_front, ARGS.augment_back
                )
            else:
                user_ranges = self._evaluation_ranges(user_idx, ninters, ARGS.seq_size)
            user_index_ranges.extend(user_ranges)
        # explicit int64 (np.long is missing or platform dependent in recent
        # NumPy); the reshape keeps a (0, 3) shape when there are no samples
        self._user_index_ranges = np.array(user_index_ranges, dtype=np.int64).reshape(-1, 3)

        name = 'Train' if self._is_training else 'Test '
        print(
            f'{name} | # of users: {len(self._users)} | # of samples: {len(self._user_index_ranges)}'
        )

        # private copy: never share a (default) list between instances
        self._aug_methods = [] if aug_methods is None else list(aug_methods)

    @staticmethod
    def _training_ranges(user_idx, ninters, seq_size, augment_front, augment_back):
        """Return the training windows ``(user_idx, start, end)`` of one user.

        Produces, in this order: every window of length
        ``min(seq_size, ninters)``; shorter prefixes ``[0, i)``, each kept
        with probability ``augment_front``; shorter suffixes ``[i, ninters)``,
        each kept with probability ``augment_back``. A user without
        interactions yields no window.
        """
        if ninters <= 0:
            # nothing to learn from; also avoids an empty (user, 0, 0) sample
            return []
        window_size = min(seq_size, ninters)
        # range of start = [0, ninters - window_size]
        ranges = [
            (user_idx, start, start + window_size)
            for start in range(ninters - window_size + 1)
        ]
        # prefixes are drawn before suffixes so the sequence of
        # random.random() calls stays reproducible under a fixed seed
        ranges.extend(
            (user_idx, 0, end)
            for end in range(1, window_size)
            if random.random() < augment_front
        )
        ranges.extend(
            (user_idx, start, ninters)
            for start in range(ninters - window_size + 1, ninters)
            if random.random() < augment_back
        )
        return ranges

    @staticmethod
    def _evaluation_ranges(user_idx, ninters, seq_size):
        """Return one window per end position, at most ``seq_size`` long."""
        return [
            (user_idx, max(end - seq_size, 0), end)
            for end in range(1, ninters + 1)
        ]

    @staticmethod
    def _pad_indices(indices, seq_size):
        """Return ``indices`` right-padded with -1 up to ``seq_size`` as int64.

        The input is not modified. Inputs longer than ``seq_size`` are
        returned unpadded and untruncated.
        """
        padding = [-1] * (seq_size - len(indices))
        return np.array(list(indices) + padding, dtype=np.int64)

    def __len__(self):
        """Return the number of windows (samples)."""
        return len(self._user_index_ranges)

    def _add_interaction_idx(self, inter):
        """Set ``inter.interaction_idx = 2 * item_idx - is_correct`` in place and return ``inter``."""
        inter.interaction_idx = 2 * inter.item_idx - inter.is_correct
        return inter

    def _tensorize(self, inters):
        """Tensorize ``inters`` with every feature, keyed by feature name."""
        return {f[0].name: f[0].tensorize(inters, ARGS.seq_size) for f in self._features}

    def __getitem__(self, item):
        """
        Return:
            dictionary of dictionary of tensors.
            possible keys: 'ori' + augmentations ('rep', 'ins', 'del')
            Each augmented view also carries an 'idx' entry: the index list
            returned by the augmentation function, padded with -1.
        Raises:
            NotImplementedError: in training mode, for an unknown
            augmentation key.
        """
        user_idx, start_index, end_index = self._user_index_ranges[item]
        user = self._users[user_idx].astype(str)
        inters = self._iparser.parse_interactions(user, start_index, end_index)  # sequence of interactions
        # add interaction indices
        if 'interaction_idx' in self._feature_names:
            inters = [self._add_interaction_idx(inter) for inter in inters]

        views = {'ori': self._tensorize(inters)}  # dictionary of dictionary
        if not self._is_training:
            return views

        # apply augmentations
        for aug in self._aug_methods:
            if aug == 'del':
                aug_inters, kept_idx = random_deletion(sequence=inters,
                                                       del_prob=ARGS.del_prob,
                                                       response=ARGS.del_response
                                                       )
                view = self._tensorize(aug_inters)
            elif aug == 'ins':
                aug_inters, kept_idx = random_insertion(sequence=inters,
                                                        ins_prob=ARGS.ins_prob,
                                                        response=ARGS.ins_response,
                                                        ins_type=ARGS.ins_type
                                                        )
                view = self._tensorize(aug_inters)
                if ARGS.ins_kt_loss:
                    # switch off the loss mask at every position that is not
                    # listed in the indices returned by random_insertion
                    # NOTE(review): needs a 'loss_mask' feature and assumes
                    # positions fit in the tensorized length - confirm
                    original_positions = set(kept_idx)
                    for pos in range(len(aug_inters)):
                        if pos not in original_positions:
                            view['loss_mask'][pos] = False
            elif aug == 'rep':
                aug_inters, kept_idx = skill_based_replacement(sequence=inters,
                                                               rep_prob=ARGS.rep_prob,
                                                               rep_type=ARGS.rep_type,
                                                               response=ARGS.rep_response
                                                               )
                view = self._tensorize(aug_inters)
            else:
                raise NotImplementedError(f'Unknown augmentation method: {aug!r}')
            view['idx'] = self._pad_indices(kept_idx, ARGS.seq_size)
            views[aug] = view

        return views
